{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6e68efed",
   "metadata": {},
   "source": [
    "# Relax Python 模块设计概述\n",
    "\n",
    "参考：[Relax Python 模块设计](https://discuss.tvm.apache.org/t/relax-python-module-design/18272) & [pull: 18229](https://github.com/apache/tvm/pull/18229)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45f762ab",
   "metadata": {},
   "source": [
    "随着机器学习模型——尤其是大型语言模型——的规模持续增长，对 ML 编译器运行时与 Python 生态系统深度集成的需求日益增加。像 PyTorch 这样的基于 Python 的框架提供了丰富的算子库，包括通过 {mod}`torch.distributed` 进行分布式通信等功能，这些功能可以在 GPU 和节点之间高效扩展。这些资源已经被广泛采用并得到良好支持，使其成为编译器运行时中重用的理想候选。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e6aa34b",
   "metadata": {},
   "source": [
    "在 TVM 中，计算图使用 Relax 在 IRModules 中描述。虽然 TVMScript 允许使用类似 Python 的语法表达 Relax 函数，但这些函数不能直接在 Python 中执行。要运行 Relax 函数，必须编译整个 IRModule，并通过虚拟机（VM）加载生成的可执行文件。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a40f37b",
   "metadata": {},
   "source": [
    "为了更好地利用 Python 的运行时环境并丰富 TVM 的灵活性，在支持 Python 的平台上的 IRModules 和 TVMScript 中添加对 Python 函数的原生支持。这些 Python 函数——用 `@py_func` 装饰器标记的——可以直接在 Python 中执行，使用标准的 PyTorch 张量作为输入和输出。类似于 Relax 函数，它们表示计算图，但额外的好处是可以直接、逐步地用 Python 执行。与需要在运行前编译的 Relax 函数不同，Python 函数将不会编译，而是可以直接在 Python 环境中运行。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fb0898c",
   "metadata": {},
   "source": [
    "除了重用 Python 和 PyTorch 实现，在 TVMScript 中支持 Python 函数可以显著提升调试体验。传统编译器将计算图视为单一实体，难以检查中间张量值。随着模型复杂性的增加，这一限制变得更加明显。通过 Python 函数，调试就像插入一条 print 语句一样简单。用户还可以快速手动编辑 Python 函数并立即观察结果——极大地改进了开发和调试工作流程。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36fb264f",
   "metadata": {},
   "source": [
    "## 关键设计\n",
    "\n",
    "### 跨层级调用\n",
    "\n",
    "Relax 中的 Python 函数设计为跨层级，意味着它们可以与 Relax 函数、TIR 函数和 TVM 打包函数互操作。这种双向互操作性允许：\n",
    "\n",
    "- 调用 Relax/TIR/打包函数的 Python 函数。\n",
    "- 通过 `R.call_py_func` 调用 Python 函数的 Relax 函数。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da3427f3",
   "metadata": {},
   "source": [
    "为了支持这一点，使用 DLPack 实现 TVM NDArrays 和 PyTorch 张量之间的无缝转换，使数据能够在不同的运行环境中流动，并最小化开销。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8ec8b45",
   "metadata": {},
   "source": [
    "### 即时编译（JIT）\n",
    "\n",
    "如果 IRModule 包含任何 Python 函数，会使用 JIT 编译延迟 TIR 和 Relax 函数的编译。这意味着：\n",
    "\n",
    "- TVMScript 解析时不会编译 TIR 和 Relax 函数。\n",
    "- 编译仅在实例化 IRModule 时发生，此时：\n",
    "    - TIR 函数会被编译并存储在实例化的模块中。\n",
    "    - 会创建 Relax 虚拟机来执行编译后的 Relax 函数。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "587382af",
   "metadata": {},
   "source": [
    "这种 JIT 策略允许更灵活的后期修改和与 Python 运行时的集成。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f76bbbd8",
   "metadata": {},
   "source": [
    "## Relax 函数与 Python 函数之间的转换\n",
    "\n",
    "由于 Relax 函数和 Python 函数都描述计算图，引入了一种新的 IRModule 打印器，将 Relax 函数转换为 Python 函数。这允许用户：\n",
    "\n",
    "- 避免手动编写 Python 函数。\n",
    "- 将 Relax IR 转换为可读和可执行的 Python 代码。\n",
    "- 直接在 Python/PyTorch 中调试或部署中间阶段的 Relax 程序。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e41f5f54",
   "metadata": {},
   "source": [
    "在此转换过程中：\n",
    "\n",
    "- 高级 Relax 算子（例如 `R.nn.relu` ）映射到相应的 PyTorch API（例如 `F.relu` ）。\n",
    "- `call_tir` 和 Relax 函数调用通过将 PyTorch 张量转换为/从 DLPack 格式并传递给编译函数来处理。\n",
    "- `call_dps_packed` 通过通过 `tvm.get_global_func` 检索压缩函数并使用 DLPack 包装的张量调用它来执行。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c3fc97a",
   "metadata": {},
   "source": [
    "这一关键特性是这种转换可以在编译过程的任何阶段发生。例如：\n",
    "\n",
    "- 在早期阶段，用户可以将 Relax 函数转换为 Python，以测试 PyTorch 的实现。\n",
    "- 在后期阶段，当模块的大部分内容被降低到 TIR 时，相同的转换允许使用 PyTorch 运行时进行测试或部署。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa1b6583",
   "metadata": {},
   "source": [
    "未来，可能还会使用一些 PyTorch 基础设施（如 FX 或导出的程序）将 Python 函数跟踪回 Relax 函数。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b30b886d",
   "metadata": {},
   "source": [
    "## 具体实现"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb0c2a65",
   "metadata": {},
   "source": [
    "通过 `@I.pyfunc` 装饰器和 `BasePyModule` ，在 TVM Relax 中实现了原生 Python 函数支持，这使 TVM 的编译流程与 Python/PyTorch 运行时环境之间能够无缝集成。这一增强功能允许用户直接在 TVMScript 中编写 Python 函数，这些函数可以与 Relax 和 TIR 函数互操作，从而提供增强的调试能力并利用现有的 PyTorch 算子库。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51d870f6",
   "metadata": {},
   "source": [
    "```{admonition} TVMScript 解析器增强\n",
    "\n",
    "- `@I.pyfunc` 装饰器：用于将 Python 函数标记为 IRModule 的集成目标\n",
    "- 双重存储格式：既存储原始字符串表示形式（用于 TVMScript 打印），也捕获 PackedFunc（用于运行时执行）\n",
    "- `ExternFunc` 表示：每个 Python 函数都表示为一个 ExternFunc 节点，节点属性存储源代码和运行时包装器\n",
    "\n",
    "```{admonition} 完整的 BasePyModule 实现\n",
    "- 基于 DLPack 的张量转换：PyTorch 张量和 TVM NDArrays 之间的无缝转换\n",
    "- 跨函数互操作性：Python 函数可以调用 Relax/TIR 函数，反之亦然\n",
    "- JIT 编译：延迟模块实例化时的编译，以支持灵活的后期修改\n",
    "- 动态函数注册：支持运行时添加 Python 函数\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59f9be2a",
   "metadata": {},
   "source": [
    "## 示例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9ff45731",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入 TVM 核心模块\n",
    "import tvm\n",
    "from tvm import relax, tir\n",
    "# 导入 BasePyModule，这是支持 Python 函数的 IRModule 基类\n",
    "from tvm.relax.base_py_module import BasePyModule\n",
    "# 导入 TVM script 相关模块，用于编写 IR、Relax 和 TIR 代码\n",
    "from tvm.script import ir as I, relax as R, tir as T\n",
    "# 导入设备相关模块\n",
    "from tvm.runtime import Device\n",
    "# 导入 PyTorch，用于演示跨框架数据转换\n",
    "import torch\n",
    "\n",
    "\n",
    "# 使用 @I.ir_module 装饰器定义一个 IR 模块，该模块继承自 BasePyModule\n",
    "@I.ir_module\n",
    "class IRModuleWithPyFunc(BasePyModule):\n",
    "    \"\"\"示例 IRModule 包含 Python 函数支持。\n",
    "    基类 BasePyModule 实现了 Python 中的跨函数调用和 JIT 编译逻辑。\n",
    "    只有继承自 BasePyModule 的 IRModule 才允许包含 Python 函数。\n",
    "    \"\"\"\n",
    "\n",
    "    # 使用 @I.pyfunc 装饰器定义一个可以在 Relax 函数中调用的 Python 函数\n",
    "    @I.pyfunc\n",
    "    def python_add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"可以从 Relax 函数调用的 Python 函数。\"\"\"\n",
    "        # 通过 DLPack 将 PyTorch 张量转换为 TVM NDArray\n",
    "        x_tvm = self._convert_pytorch_to_tvm(x)\n",
    "        y_tvm = self._convert_pytorch_to_tvm(y)\n",
    "        \n",
    "        # 调用编译后的 TIR 函数执行加法运算\n",
    "        result = self.call_tir(self.add_tir, [x_tvm, y_tvm], \n",
    "                             out_sinfo=R.Tensor((5,), \"float32\"))\n",
    "        \n",
    "        # 将结果转换回原始格式（PyTorch 张量）\n",
    "        return self._convert_tvm_to_pytorch(result)\n",
    "\n",
    "    # 使用 @T.prim_func 装饰器定义一个 TIR 原语函数\n",
    "    @T.prim_func\n",
    "    def add_tir(\n",
    "        var_x: T.handle,\n",
    "        var_y: T.handle,\n",
    "        var_out: T.handle,\n",
    "    ):\n",
    "        # 匹配缓冲区，将原始句柄绑定到具体的缓冲区描述\n",
    "        x = T.match_buffer(var_x, (5,), \"float32\")\n",
    "        y = T.match_buffer(var_y, (5,), \"float32\")\n",
    "        out = T.match_buffer(var_out, (5,), \"float32\")\n",
    "        \n",
    "        # 实现向量加法运算\n",
    "        for i in range(5):\n",
    "            out[i] = x[i] + y[i]\n",
    "\n",
    "    # 使用 @R.function 装饰器定义一个 Relax 函数\n",
    "    @R.function\n",
    "    def main_relax(x: R.Tensor((5,), \"float32\"), \n",
    "                   y: R.Tensor((5,), \"float32\")) -> R.Tensor((5,), \"float32\"):\n",
    "        # 直接使用 Relax 的内置加法操作\n",
    "        return R.add(x, y)\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"展示带有 Python 函数支持的 IRModule 的主函数。\"\"\"\n",
    "    # 创建 IRModuleWithPyFunc 实例\n",
    "    module = IRModuleWithPyFunc()\n",
    "    \n",
    "    # 生成用于测试的随机 PyTorch 张量\n",
    "    x_torch = torch.randn(5, dtype=torch.float32)\n",
    "    y_torch = torch.randn(5, dtype=torch.float32)\n",
    "    \n",
    "    # 通过 DLPack 将 PyTorch 张量转换为 TVM NDArray\n",
    "    x_tvm = module._convert_pytorch_to_tvm(x_torch)\n",
    "    y_tvm = module._convert_pytorch_to_tvm(y_torch)\n",
    "    \n",
    "    # 将 TVM NDArray 转换回 PyTorch 张量\n",
    "    x_back = module._convert_tvm_to_pytorch(x_tvm)\n",
    "    y_back = module._convert_tvm_to_pytorch(y_tvm)\n",
    "    \n",
    "    # 执行跨函数调用测试\n",
    "    # 1. 调用 TIR 函数\n",
    "    tir_result = module.call_tir(\"add_tir\", [x_torch, y_torch], \n",
    "                                out_sinfo=R.Tensor((5,), \"float32\"))\n",
    "    # 2. 调用 Relax 函数\n",
    "    relax_result = module.main_relax(x_torch, y_torch)\n",
    "    # 3. 调用 Python 函数\n",
    "    python_result = module.python_add(x_torch, y_torch)\n",
    "    \n",
    "    # 返回模块实例、DLPack 转换结果和跨函数调用结果\n",
    "    return module, (x_torch, y_torch, x_tvm, y_tvm, x_back, y_back), (tir_result, relax_result, python_result)\n",
    "\n",
    "\n",
    "# 当脚本直接运行时执行主函数\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n",
    "\n",
    "\n",
    "\n",
    "# 示例用法与验证代码（当前被注释掉）\n",
    "# result = main()\n",
    "# assert result is not None, \"函数应返回结果\"\n",
    "# module, dlpack_results, cross_call_results = result\n",
    "# assert len(dlpack_results) == 6, \"DLPack 结果应包含 6 个元素\"\n",
    "# assert len(cross_call_results) == 3, \"跨调用结果应包含 3 个元素\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42a362a4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py313",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
